# 通过yolohead获得预测结果

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from PANet import panet  # 导入panet加强特征提取方法
from CSPDarknet53 import conv_block  # 导入标准卷积快

# 对PANet的特征输出层处理获得最终的预测结果
def yoloHead(inputs, num_anchors, num_classes):
    '''
    num_anchors每个网格包含先验框的数量, num_classes分类数
    num_anchors(5+num_classes)代表: 每个先验框有5+num_classes个参数, 即(x,y,w,h,c)和20个类别的条件概率
    每一个特征层的输出代表: 每一个网格上每一个先验框内部是否包含物体, 以及包含物体的种类, 和先验框的调整参数
    '''

    # 获得三个有效特征层
    p3_output, p4_output, p5_output = panet(inputs)

    # 3*3卷积[52,52,128]==>[52,52,256]
    p3_output = conv_block(p3_output, filters=256, kernel_size=(3,3), strides=1)
    # [52,52,256]==>[52,52,num_anchors(5+num_classes)]
    p3_output = conv_block(p3_output, filters=num_anchors*(5+num_classes),
                           kernel_size=(1,1), strides=1)

    # [26,26,256]==>[26,26,516]
    p4_output = conv_block(p4_output, filters=512, kernel_size=(3,3), strides=1)
    # [26,26,512]==>[26,26,num_anchors(5+num_classes)]
    p4_output = conv_block(p4_output, filters=num_anchors*(5+num_classes),
                           kernel_size=(1,1), strides=1)

    # [13,13,512]==>[13,13,1024]
    p5_output = conv_block(p5_output, filters=1024, kernel_size=(3,3), strides=1)
    # [13,13,1024]==>[13,13,num_anchors(5+num_classes)]
    p5_output = conv_block(p5_output, filters=num_anchors*(5+num_classes),
                           kernel_size=(1,1), strides=1)
    
    # 构建模型
    model = Model(inputs, [p5_output, p4_output, p3_output])

    return model

# 查看模型结构
if __name__ == '__main__':

    inputs = keras.Input(shape=[416,416,3])  # 构造输入
    # 接收模型，传入先验框数量3，分类数20
    model = yoloHead(inputs, num_anchors=3, num_classes=20)
    # 网络架构
    model.summary()